# Example: fitting and interpreting an EBM regressor on the Hitters data
library(ebm)


# Load the Hitters data set shipped with the ISLR2 package
data("Hitters", package = "ISLR2")

# Remove rows with missing response values. The assignment is made on its own
# line (rather than inside the call to head()) so that the creation of
# `hitters` is not hidden as a side effect of previewing it.
hitters <- Hitters[!is.na(Hitters$Salary), ]
head(hitters)
#>                   AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun
#> -Alan Ashby         315   81     7   24  38    39    14   3449   835     69
#> -Alvin Davis        479  130    18   66  72    76     3   1624   457     63
#> -Andre Dawson       496  141    20   65  78    37    11   5628  1575    225
#> -Andres Galarraga   321   87    10   39  42    30     2    396   101     12
#> -Alfredo Griffin    594  169     4   74  51    35    11   4408  1133     19
#> -Al Newman          185   37     1   23   8    21     2    214    42      1
#>                   CRuns CRBI CWalks League Division PutOuts Assists Errors
#> -Alan Ashby         321  414    375      N        W     632      43     10
#> -Alvin Davis        224  266    263      A        W     880      82     14
#> -Andre Dawson       828  838    354      N        E     200      11      3
#> -Andres Galarraga    48   46     33      N        E     805      40      4
#> -Alfredo Griffin    501  336    194      A        W     282     421     25
#> -Al Newman           30    9     24      N        E      76     127      7
#>                   Salary NewLeague
#> -Alan Ashby        475.0         N
#> -Alvin Davis       480.0         A
#> -Andre Dawson      500.0         N
#> -Andres Galarraga   91.5         N
#> -Alfredo Griffin   750.0         A
#> -Al Newman          70.0         A
# Fit a default EBM regressor: Salary is modeled on every other column of
# `hitters` (formula `Salary ~ .`), with objective = "rmse"
fit <- ebm(Salary ~ ., data = hitters, objective = "rmse")
# Auto-printing the fitted object currently shows only the representation
# recorded below
fit  # TODO: print() and summary() methods are not implemented yet
#> ExplainableBoostingRegressor(early_stopping_tolerance=0)
# Predictions for the first six rows of the data the model was fitted on
head(predict(fit, newdata = hitters))
#> [1] 489.4548 626.1970 870.3430 169.8797 659.6543 270.5382
# With se.fit = TRUE a two-column matrix is returned: column 1 matches the
# predictions above; column 2 is presumably the accompanying standard
# error / uncertainty estimate -- TODO confirm against the predict() docs
head(predict(fit, newdata = hitters, se.fit = TRUE))
#>          [,1]      [,2]
#> [1,] 489.4548  53.92471
#> [2,] 626.1970  98.39167
#> [3,] 870.3430 241.04199
#> [4,] 169.8797  43.14385
#> [5,] 659.6543  54.02381
#> [6,] 270.5382 155.71903
# Plot the fitted model without selecting a term; display = "markdown"
# presumably picks a rendering suited to knitted documents -- TODO confirm
plot(fit, display = "markdown")
# Plot the "Years" term as originally fitted (before any adjustment)
plot(fit, term = "Years", display = "markdown")
# Call monotonize() on the "Years" term with increasing = FALSE. The result is
# not reassigned and the next plot reuses `fit`, so this presumably modifies
# the model in place -- TODO confirm
fit$monotonize("Years", increasing = FALSE)
#> ExplainableBoostingRegressor(early_stopping_tolerance=0)
# Re-plot the "Years" term after monotonize() to compare with the plot above
plot(fit, term = "Years", display = "markdown")
# Explain a single prediction: take the first observation's predictors (every
# column except the response, Salary) and pass them, together with the
# observed salary for that row, to the local plot
x <- hitters[1L, setdiff(names(hitters), "Salary"), drop = FALSE]
plot(
  fit,
  local = TRUE,
  X = x,
  y = hitters$Salary[1L],
  display = "markdown"
)